1. Load Required Libraries

# Bioconductor packages
if (!requireNamespace("BiocManager", quietly = TRUE)) {
  install.packages("BiocManager", repos = "https://cloud.r-project.org")
}

# Seurat is installed from the satijalab r-universe. CRAN is listed as a second
# repository so that Seurat's CRAN-hosted dependencies can also be resolved on
# a fresh library; with the r-universe URL alone they may not be found.
# It is better to use Seurat v4 here. NOTE: this call installs whatever release
# the repositories currently serve, so pin an older version manually if needed.
if (!requireNamespace("Seurat", quietly = TRUE)) {
  install.packages(
    "Seurat",
    repos = c("https://satijalab.r-universe.dev", "https://cloud.r-project.org")
  )
}
library(Seurat)

if (!requireNamespace("BiocParallel", quietly = TRUE)) {
  BiocManager::install("BiocParallel", ask = FALSE, update = FALSE)
}
library(BiocParallel)

if (!requireNamespace("scDblFinder", quietly = TRUE)) {
  BiocManager::install("scDblFinder", ask = FALSE, update = FALSE)
}
library(scDblFinder)

if (!requireNamespace("SingleCellExperiment", quietly = TRUE)) {
  BiocManager::install("SingleCellExperiment", ask = FALSE, update = FALSE)
}
library(SingleCellExperiment)

if (!requireNamespace("scater", quietly = TRUE)) {
  BiocManager::install("scater", ask = FALSE, update = FALSE)
}
library(scater)

if (!requireNamespace("scran", quietly = TRUE)) {
  BiocManager::install("scran", ask = FALSE, update = FALSE)
}
library(scran)

if (!requireNamespace("splatter", quietly = TRUE)) {
  BiocManager::install("splatter", ask = FALSE, update = FALSE)
}
library(splatter)

# CRAN packages
if (!requireNamespace("plotly", quietly = TRUE)) {
  install.packages("plotly", repos = "https://cloud.r-project.org")
}
library(plotly)

if (!requireNamespace("PRROC", quietly = TRUE)) {
  install.packages("PRROC", repos = "https://cloud.r-project.org")
}
library(PRROC)

if (!requireNamespace("mclust", quietly = TRUE)) {
  install.packages("mclust", repos = "https://cloud.r-project.org")
}
library(mclust)

if (!requireNamespace("igraph", quietly = TRUE)) {
  install.packages("igraph", repos = "https://cloud.r-project.org")
}
library(igraph)

if (!requireNamespace("cowplot", quietly = TRUE)) {
  install.packages("cowplot", repos = "https://cloud.r-project.org")
}
library(cowplot)

# Make the minimal theme the session-wide default for the plots drawn below.
# NOTE(review): theme_set()/theme_minimal() are ggplot2 functions and ggplot2
# is not attached explicitly above -- presumably it comes in through one of the
# loaded packages; confirm.
theme_set(theme_minimal())
# Load project-local helpers. The path is relative, so it is resolved against
# the current working directory.
# NOTE(review): propHomotypic(), plotROCs() and dblTypesScheme() are called
# later but not defined in this file -- presumably they are defined here.
source("../misc.R")

# Running counter: number of the last Extended Data figure handed out.
suppFigN <- 0L

# Build an "Extended Data - Figure <nb>" label, advancing the global counter
# `suppFigN` as a side effect.
#
# Arguments:
#   increment  How many figure numbers to consume. NULL (default) means one
#              when `nb` is not supplied and none when it is. FALSE leaves the
#              counter untouched. A value > 1 reserves that many numbers and,
#              when `nb` is NULL, labels them as a range "first-last".
#   nb         Optional explicit figure number or label. A negative *number*
#              is interpreted relative to the current counter (e.g. -1 is the
#              figure before the current one). Character labels are used
#              verbatim.
#
# Returns: a character scalar.
suppFig <- function(increment = NULL, nb = NULL) {
  if (is.null(increment)) increment <- is.null(nb)
  if (!isFALSE(increment)) {
    assign("suppFigN", suppFigN + as.integer(increment), envir = .GlobalEnv)
    if (increment > 1 && is.null(nb)) {
      nb <- paste0(suppFigN - increment + 1, "-", suppFigN)
    }
  }
  if (is.null(nb)) nb <- suppFigN
  # Only numbers can be relative offsets. A character label (such as the range
  # built above) must not reach `<`, which would compare strings using the
  # locale's collation and could then feed a string into the addition.
  if (is.numeric(nb) && nb < 0) nb <- suppFigN + nb
  paste0("Extended Data - Figure ", nb)
}

2. Import and Convert Data

# Read the serialized input object (path is relative to the working directory).
# NOTE(review): presumably a Seurat object, since it is handed to
# as.SingleCellExperiment() right below -- confirm.
sc_data <- readRDS("./dataset/pbmc.rds")

# Convert to a SingleCellExperiment. Later steps read sce$batch (QC and
# variance modelling) and the "seurat_clusters" column (doublet calling), so
# those per-cell annotations are expected to be present after the conversion.
sce <- as.SingleCellExperiment(sc_data)

3. Data Preprocessing

# --------------------------
# Core preprocessing steps
# --------------------------
# Every step overwrites `sce` in place, so the order of the statements matters.
# NOTE(review): the scDblFinder documentation suggests calling doublets before
# stringent cell QC; here QC filtering and HVG subsetting happen first --
# confirm this ordering is intended.

# 1. Quality control: filter low-quality cells
# Per-cell QC metrics are added, with features whose row name starts with
# "MT-" grouped as the "Mito" subset (presumably mitochondrial gene symbols).
sce <- addPerCellQC(sce, subsets = list(Mito = grep("^MT-", rownames(sce))))
# Outlier-based discard decision, also using the Mito percentage; `batch` is
# taken from sce$batch.
qc_metrics <- quickPerCellQC(
  colData(sce),
  percent_subsets = c("subsets_Mito_percent"),
  batch = sce$batch
)
# Keep only the cells that were not flagged for discard.
sce <- sce[, qc_metrics$discard == FALSE]

# 2. Normalization using scran's method
# Seed set before clustering so the run is repeatable.
set.seed(123)
clusters <- quickCluster(sce, method = "igraph")
sce      <- computeSumFactors(sce, clusters = clusters)
sce      <- logNormCounts(sce)

# 3. High-variance gene selection to improve doublet detection sensitivity
# Variance is modelled with sce$batch as blocking factor; the object is then
# restricted to the top 3000 genes, so every later step (including
# scDblFinder) only sees these features.
dec <- modelGeneVar(sce, block = sce$batch)
hvg <- getTopHVGs(dec, n = 3000)
sce <- sce[hvg, ]

# 4. Dimensionality reduction: 50 principal components on scaled features.
# The result is stored as the "PCA" reduced dimension used by the 3D plot.
set.seed(123)
sce <- runPCA(sce, ncomponents = 50, scale = TRUE)

4. Doublet Calling

# ------------------------------------------------
# Use Seurat clusters to guide doublet calling
# ------------------------------------------------

# Assume you already have Seurat clusters in sce$seurat_clusters
# `clusters` is given as the name of that per-cell column. The call adds the
# scDblFinder.* columns (class, score, cluster) that the following sections
# read from `sce`.
# Lines starting with "##" are console output recorded from a previous run.
sce <- scDblFinder(sce, clusters = "seurat_clusters")
## 9 clusters
## Creating ~2007 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 103 cells excluded from training.
## iter=1, 81 cells excluded from training.
## iter=2, 76 cells excluded from training.
## Threshold found:0.417
## 74 (3%) doublets called
# Print how many doublets were detected
# (counts of cells per scDblFinder.class level)
table(sce$scDblFinder.class)
## 
## singlet doublet 
##    2434      74

5. Summary of Doublet QC

# Console report of the doublet calls. The temporaries live inside local() so
# nothing is added to the global environment. Lines starting with "##" are the
# output recorded from a previous run.
local({
  # TRUE for every cell that scDblFinder classified as a doublet.
  dbl_flag <- sce$scDblFinder.class == "doublet"

  cat("=== Doublet QC Summary ===\n")
  cat("Number of doublets detected:", sum(dbl_flag), "\n")
  cat("Doublet rate:", round(mean(dbl_flag) * 100, 2), "%\n")

  # Doublets tabulated per cluster; report the cluster with the largest count.
  dbl_by_cluster <- table(sce$scDblFinder.cluster[dbl_flag])
  cat("Cluster with highest doublet count:",
      names(which.max(dbl_by_cluster)),
      "\n\n")
})
## === Doublet QC Summary ===
## Number of doublets detected: 74
## Doublet rate: 2.95 %
## Cluster with highest doublet count: 2

# Distribution of the per-cell doublet scores.
summary(sce$scDblFinder.score)
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## 0.0002177 0.0018206 0.0040550 0.0395746 0.0107350 0.9996960

6. 3D PCA Visualization

# Interactive 3D scatter of the first three principal components, coloured by
# the scDblFinder class using a grey/red palette.
pca_coords <- as.data.frame(reducedDim(sce, "PCA")[, 1:3])
layout(
  add_markers(
    plot_ly(
      pca_coords,
      x = ~PC1,
      y = ~PC2,
      z = ~PC3,
      color = ~sce$scDblFinder.class,
      colors = c("grey", "red"),
      marker = list(size = 3)
    )
  ),
  title = "3D PCA: Doublet vs. Singlet"
)

7. ROC Curve with SNP‑based Ground Truth

# Load a processed object that already carries, per cell, the scDblFinder
# results (scDblFinder.score / .class / .cluster), an `ind` label and the
# `multiplets` call used as ground truth -- all of these are read below.
# NOTE(review): `ind` is presumably the donor/individual of each cell and
# `multiplets` the SNP-based call -- confirm against the dataset.
e <- readRDS("./dataset/GSM2560248_noAmbiguous.processed.CD.rds")
# proportion homotypic doublets: these will be called as false negatives
prop.homotypic <- propHomotypic(e$scDblFinder.cluster)
# proportion within/intra-individual doublets: these will be called as false positives
prop.intraind <- propHomotypic(e$ind)
# Operating point of the actual calls: `th` is the lowest score among cells
# classified as doublets and `w` indexes every cell scoring at least that much.
th <- min(e$scDblFinder.score[e$scDblFinder.class=="doublet"])
w <- which(e$scDblFinder.score>=th)
# 1 if a selected cell is a true doublet, 0 otherwise (sorted by score; the
# order does not affect the sums below).
x <- as.integer(e$multiplets[w]=="doublet")[order(e$scDblFinder.score[w])]
# d$x: share of selected cells that are not true doublets (FDR);
# d$y: share of all true doublets that were selected (TPR).
d <- data.frame(x=sum(!x)/length(x),
                y=sum(x)/sum(e$multiplets=="doublet"))
d
##           x         y
## 1 0.3254371 0.7484355
# ROC-type curve of the score against the ground truth, with the operating
# point `d` overlaid as a single dot.
# NOTE(review): plotROCs() is not defined in this file. The comments above
# pair homotypic with "false negatives" and intra-individual with "false
# positives", while the call passes prop.wrong.neg=prop.intraind and
# prop.wrong.pos=prop.homotypic -- confirm the argument semantics in the helper.
p2 <- plotROCs(list(score=e$scDblFinder.score), e$multiplets=="doublet", fdr=TRUE, 
               prop.wrong.neg=prop.intraind, prop.wrong.pos=prop.homotypic, addNull=TRUE,
               showLegend=FALSE) + scale_color_manual(values=c("score"="darkviolet")) +
  geom_point(data=d, aes(x,y), size=3.5) + 
  labs(x="FDR (SNP-based)", y="TPR (SNP-based)")

# Two-panel figure: the output of dblTypesScheme() (helper not defined in this
# file) next to the ROC panel, labelled A and B automatically.
p <- plot_grid( dblTypesScheme(), 
                p2 + theme(legend.position="none"), 
                scale=0.95, labels="AUTO")

# Uncomment to write the figure to a PDF instead of the active device.
# pdf("Figure4.pdf", width=9, height=4)
# p
# dev.off()
p

8. Simulation Benchmark: Precision / Recall / F1

# Benchmark on simulated data: for each doublet rate, simulate singlets with
# splatter, add synthetic doublets (sum of two randomly drawn singlet count
# columns), run scDblFinder and score its calls against the known truth.
set.seed(123)
rates    <- c(0.001, 0.005, 0.01, 0.05)
metrics  <- data.frame(rate = rates, Precision = NA_real_,
                       Recall = NA_real_, F1 = NA_real_)
# Will hold the simulated object of the 1% run for the PR curve further down.
sce_1pct <- NULL

for (i in seq_along(rates)) {
  rate      <- rates[i]
  n_singlet <- 3000
  n_doublet <- round(rate * n_singlet)

  # simulate
  params <- newSplatParams(batchCells = n_singlet)
  sce0   <- splatSimulate(params, method = "groups", group.prob = c(0.5, 0.5),
                          verbose = FALSE)
  mat0   <- counts(sce0)
  idxs   <- sample(n_singlet, 2 * n_doublet, replace = TRUE)
  # First half of `idxs` = first parent, second half = second parent.
  # seq_len() keeps this correct when n_doublet is 0 (1:0 would count down),
  # and drop = FALSE keeps a matrix when there is a single doublet.
  parent_a <- idxs[seq_len(n_doublet)]
  parent_b <- idxs[n_doublet + seq_len(n_doublet)]
  mat_dbl  <- mat0[, parent_a, drop = FALSE] + mat0[, parent_b, drop = FALSE]
  sce_all  <- SingleCellExperiment(assays = list(counts = cbind(mat0, mat_dbl)))
  # Singlets come first, the synthetic doublets are appended after them.
  sce_all$isDoublet <- rep(c(FALSE, TRUE), c(n_singlet, n_doublet))

  # detect
  sce_all <- logNormCounts(sce_all)
  sce_all <- scDblFinder(sce_all)

  # metrics
  truth <- sce_all$isDoublet
  pred  <- sce_all$scDblFinder.class == "doublet"
  TP <- sum(pred & truth)
  FP <- sum(pred & !truth)
  FN <- sum(!pred & truth)
  # These are NaN when their denominator is zero (e.g. no cell called doublet).
  prec <- TP / (TP + FP)
  rec  <- TP / (TP + FN)
  f1   <- 2 * prec * rec / (prec + rec)
  metrics[i, 2:4] <- c(prec, rec, f1)

  # Tolerance-based comparison instead of `==` on a floating-point rate.
  if (isTRUE(all.equal(rate, 0.01))) sce_1pct <- sce_all
}
## Creating ~2403 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 169 cells excluded from training.
## iter=1, 214 cells excluded from training.
## iter=2, 239 cells excluded from training.
## Threshold found:0.389
## 69 (2.3%) doublets called
## Creating ~2412 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 197 cells excluded from training.
## iter=1, 223 cells excluded from training.
## iter=2, 249 cells excluded from training.
## Threshold found:0.304
## 82 (2.7%) doublets called
## Creating ~2424 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 210 cells excluded from training.
## iter=1, 242 cells excluded from training.
## iter=2, 277 cells excluded from training.
## Threshold found:0.25
## 100 (3.3%) doublets called
## Creating ~2520 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 329 cells excluded from training.
## iter=1, 379 cells excluded from training.
## iter=2, 399 cells excluded from training.
## Threshold found:0.385
## 214 (6.8%) doublets called
# Precision / recall / F1 per simulated doublet rate. Lines starting with "##"
# are the output recorded from a previous run.
print(metrics)
##    rate  Precision    Recall         F1
## 1 0.001 0.04347826 1.0000000 0.08333333
## 2 0.005 0.17073171 0.9333333 0.28865979
## 3 0.010 0.27000000 0.9000000 0.41538462
## 4 0.050 0.68691589 0.9800000 0.80769231
# PR curve at 1%
# The scores of the true doublets are passed as class 0 and those of the
# singlets as class 1.
truth  <- sce_1pct$isDoublet
scores <- sce_1pct$scDblFinder.score
pr <- pr.curve(
  scores.class0 = scores[truth],
  scores.class1 = scores[!truth],
  curve = TRUE
)
plot(pr, main = "PR Curve (1% doublets)")

9. Simulation Benchmark: Clustering Impact (ARI)

# Benchmark on simulated data: how does removing the cells called as doublets
# change the agreement (adjusted Rand index, ARI) between graph-based clusters
# and the simulated groups? ARI is always evaluated on singlets only.
set.seed(123)
rates <- c(0.01, 0.03, 0.05, 0.10)
res   <- data.frame(rate = rates, ARI_before = NA_real_,
                    ARI_after = NA_real_, ARI_change_pct = NA_real_)

for (i in seq_along(rates)) {
  rate      <- rates[i]
  n_singlet <- 5000
  n_doublet <- round(rate * n_singlet)
  params    <- newSplatParams(batchCells = n_singlet)
  sce0      <- splatSimulate(params, method = "groups", group.prob = c(0.5, 0.5),
                             verbose = FALSE)
  mat0      <- counts(sce0)
  true_cl   <- colData(sce0)$Group
  idxs      <- sample(n_singlet, 2 * n_doublet, replace = TRUE)
  # Synthetic doublet = sum of two sampled singlet columns. seq_len() stays
  # correct for n_doublet == 0 and drop = FALSE keeps a matrix for a single
  # doublet.
  mat_dbl   <- mat0[, idxs[seq_len(n_doublet)], drop = FALSE] +
    mat0[, idxs[n_doublet + seq_len(n_doublet)], drop = FALSE]
  sce_all   <- SingleCellExperiment(assays = list(counts = cbind(mat0, mat_dbl)))
  # Singlets occupy columns 1..n_singlet, doublets are appended after them.
  sce_all$isDoublet <- rep(c(FALSE, TRUE), c(n_singlet, n_doublet))
  sce_all <- logNormCounts(sce_all)
  sce_all <- scDblFinder(sce_all)

  # ARI before filtering
  sce_all <- runPCA(sce_all, ncomponents = 10)
  g_all   <- buildSNNGraph(sce_all, use.dimred = "PCA", k = 10)
  cl_all  <- cluster_walktrap(g_all)$membership
  ari_b   <- adjustedRandIndex(cl_all[seq_len(n_singlet)], true_cl)

  # ARI after filtering
  keep    <- which(sce_all$scDblFinder.class != "doublet")
  sce_ft  <- sce_all[, keep]
  sce_ft  <- runPCA(sce_ft, ncomponents = 10)
  g_ft    <- buildSNNGraph(sce_ft, use.dimred = "PCA", k = 10)
  cl_ft   <- cluster_walktrap(g_ft)$membership
  names(cl_ft) <- colnames(sce_ft)
  # `cl_ft` is in the order of `keep`, so the retained singlets are selected
  # by position. Column names are not a safe key here: each doublet column
  # inherits the name of one of its parent cells, so names are not unique.
  kept_is_singlet <- keep <= n_singlet
  sing_idx <- keep[kept_is_singlet]
  ari_a   <- adjustedRandIndex(cl_ft[kept_is_singlet], true_cl[sing_idx])

  # Relative change (after - before) / before: positive means the ARI went up
  # after filtering, negative means it went down.
  res[i, 2:4] <- c(ari_b, ari_a, (ari_a - ari_b) / ari_b * 100)
}
## Creating ~4040 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 370 cells excluded from training.
## iter=1, 461 cells excluded from training.
## iter=2, 478 cells excluded from training.
## Threshold found:0.351
## 148 (2.9%) doublets called
## Creating ~4120 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 547 cells excluded from training.
## iter=1, 583 cells excluded from training.
## iter=2, 612 cells excluded from training.
## Threshold found:0.386
## 265 (5.1%) doublets called
## Creating ~4200 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 651 cells excluded from training.
## iter=1, 696 cells excluded from training.
## iter=2, 715 cells excluded from training.
## Threshold found:0.428
## 338 (6.4%) doublets called
## Creating ~4400 artificial doublets...
## Dimensional reduction
## Evaluating kNN...
## Training model...
## iter=0, 779 cells excluded from training.
## iter=1, 842 cells excluded from training.
## iter=2, 880 cells excluded from training.
## Threshold found:0.668
## 538 (9.8%) doublets called
# Recorded output of a previous run; ARI_change_pct is shown with the
# (after - before) sign convention used above.
print(res)
##   rate ARI_before ARI_after ARI_change_pct
## 1 0.01  0.5235556 0.5427335       3.663023
## 2 0.03  0.5184449 0.5482957       5.757769
## 3 0.05  0.7574871 0.5058154     -33.224546
## 4 0.10  0.8748746 0.5326731     -39.114350